import sys
import time
import timer
import traceback
from copy import copy

import pyrealsense2 as rs
import numpy as np
import cv2
from camera_utils import set_dev_preset, retrieve_aligned_pipeline, get_images_from_pipeline
from detection import find_pts, reduce_pts, find_grasper, find_object_ransac, find_pts_grasper
from geometry import create_xyz, draw_corners, set_intrinsics, draw_grasper, draw_object
from multiprocessing.dummy import Queue, Event, Process as Thread, Value
from multiprocessing import Process, Queue as ProcessQueue, Event as ProcessEvent


# from multiprocessing import Pool


# Camera intrinsics: focal lengths (fx, fy) and principal point (cx, cy),
# presumably in pixels, for a 1280x720 frame.
# NOTE(review): not referenced by any code visible in this file, and the main
# block requests 640x480 frames -- confirm whether this is still used/valid.
intrinsics = {'fx': 923.696, 'fy': 923.115, 'height': 720, 'width': 1280, 'cx': 630.634, 'cy': 367.255}

# Reference colour per cube colour. Each detection thread passes colors[name] and
# color_devs[name] to find_pts, and stores colors[name] on each detection under
# 'color_lab'. The values are therefore presumably CIELAB channel means, and the
# devs per-channel spreads/tolerances -- confirm against the calibration that
# produced them. The keys of `colors` also decide which cube colours get a
# detection thread and queues in the main block.
colors = {'red': [79.46608003, 173.12259944, 159.88620013], 'green': [49.37390089, 114.06434519, 131.26390264],
          'yellow': [147.95821667, 131.98808242, 189.42594587],
          'blue': [48.36433141, 129.38893223, 107.09408644]}  # alternative 'blue' calibration: [66.53376688, 147.5937969, 92.63081541]
color_devs = {'red': [12.45986815, 5.20190592, 4.17891806], 'green': [9.3610654, 2.21609464, 1.59495349],
              'yellow': [28.13293378, 4.28111556, 9.63940965], 'blue': [9.6500446, 1.58701702, 3.78901351]}

# Reference colour and spread for the grasper marker, passed to find_pts_grasper
# and stored on the grasper detection as 'color_lab' (same convention as above).
grasper_color = [184.13556338,  69.29401408, 178.52570423]
grasper_dev = [17.62697205,  4.28695607,  3.82532902]

# Presumably an (x, y, z) offset correction; a previously used value was [0.020, 0, 0].
# NOTE(review): not referenced by any code visible in this file -- confirm it is still needed.
correction = [0, 0, 0]


def capture_thread_target(q_out, e_stop, width, height):
    """Stream aligned colour/depth frames from a live RealSense device into ``q_out``.

    Loads the short-range device preset, opens an aligned pipeline at the
    requested resolution, and discards five initial frames. It then pushes one
    ``[color_image, depth_image]`` pair per frame until ``e_stop`` is set or an
    error occurs.

    Args:
        q_out: Queue receiving ``[color_image, depth_image]`` lists.
        e_stop: Shared stop event. Checked before every frame and always set
            on exit, including when device setup fails, so that the other
            stages shut down too.
        width: Requested frame width in pixels.
        height: Requested frame height in pixels.

    Raises:
        queue.Full: If ``q_out`` stays full for 300 s (no consumer).
        Any exception from the camera helpers propagates after cleanup.
    """
    print("Capture thread started.")
    pipeline = None
    try:
        set_dev_preset("presets/ShortRangePreset.json")
        # Alternative preset: "presets/HighResHighAccuracyPreset.json"
        pipeline, align = retrieve_aligned_pipeline(width=width, height=height, verbose=True)

        # Throw away the first few frames (presumably so the sensor can settle).
        for _ in range(5):
            pipeline.wait_for_frames()

        while not e_stop.is_set():
            color_image, depth_image = get_images_from_pipeline(pipeline, align)
            q_out.put([color_image, depth_image], timeout=300)
    finally:
        # Signal shutdown even when setup itself failed. Otherwise consumers
        # wait out their 300 s queue timeouts for frames that never arrive.
        e_stop.set()
        if pipeline is not None:
            pipeline.stop()


def capture_thread_from_bag_target(q_out, e_stop, filename, width, height):
    """Stream aligned colour/depth frames from a recorded file into ``q_out``.

    Opens an aligned pipeline on ``filename`` via ``retrieve_aligned_pipeline``.
    It then pushes one ``[color_image, depth_image]`` pair per frame until
    ``e_stop`` is set or an error occurs (for example, from the pipeline helper
    when the recording can no longer deliver frames).

    Args:
        q_out: Queue receiving ``[color_image, depth_image]`` lists.
        e_stop: Shared stop event. Checked before every frame and always set
            on exit, including when opening the recording fails, so that the
            other stages shut down too.
        filename: Path of the recording to play back.
        width: Requested frame width in pixels.
        height: Requested frame height in pixels.

    Raises:
        queue.Full: If ``q_out`` stays full for 300 s (no consumer).
        Any exception from the camera helpers propagates after cleanup.
    """
    print("Capture thread started.")
    pipeline = None
    try:
        pipeline, align = retrieve_aligned_pipeline(filename=filename, width=width, height=height, verbose=True)
        while not e_stop.is_set():
            color_image, depth_image = get_images_from_pipeline(pipeline, align)
            q_out.put([color_image, depth_image], timeout=300)
    finally:
        # Signal shutdown even when opening the recording failed, so consumers
        # do not wait out their 300 s queue timeouts.
        e_stop.set()
        if pipeline is not None:
            pipeline.stop()


def capture_from_file_thread_target(q_out, e_stop):
    """Replay frames saved as ``collect/6.npy`` ... ``collect/199.npy`` into ``q_out``.

    Each file must hold a 3-D array whose last axis has at least four channels.
    Channels 0-2 are the colour image (cast to ``uint8``) and channel 3 is the
    depth. Each frame is put on ``q_out`` as ``[rgb, depth]``.

    Replay stops at the first missing file (end of the recording), on any other
    error (whose traceback is printed), or as soon as ``e_stop`` is set by
    another stage. ``e_stop`` is always set on return, so downstream stages
    shut down once replay is over.

    Args:
        q_out: Queue receiving ``[rgb, depth]`` lists.
        e_stop: Shared stop event.
    """
    print("Capture from file thread started.")
    try:
        for i in range(6, 200):
            if e_stop.is_set():
                break
            path = "collect/{}.npy".format(i)
            try:
                ar = np.load(path)
            except FileNotFoundError:
                # A gap in the numbering marks the end of the recording.
                print("No more recorded frames ({} not found).".format(path))
                break
            rgb = ar[:, :, :3].astype('uint8')
            depth = ar[:, :, 3]
            q_out.put([rgb, depth], timeout=300)
    except Exception:
        traceback.print_exc()
    finally:
        e_stop.set()


def preprocess_thread_target(q_in, q_bgr_img_out, q_grasper_in, q_cubes_out_dict, e_stop):
    """Turn raw ``[color_image, depth_image]`` frames into work for the detectors.

    For each frame taken from ``q_in``:
      * a channel-reversed copy of the colour image (the BGR image that the
        visualiser displays) is put on ``q_bgr_img_out``;
      * an ``xyz`` array is built from the depth image with ``create_xyz`` and
        reduced together with the colour image by ``reduce_pts``;
      * the resulting ``[rgb, pts]`` is fanned out to ``q_grasper_in`` and to
        every queue in ``q_cubes_out_dict``.

    Every put uses a 10 s timeout. A dead consumer therefore makes this stage
    fail and set ``e_stop`` instead of blocking forever. Any exception is
    printed and sets ``e_stop``, which ends this loop.

    Args:
        q_in: Queue of ``[color_image, depth_image]`` lists from the capture stage.
        q_bgr_img_out: Queue feeding display images to the visualiser.
        q_grasper_in: Input queue of the grasper detector.
        q_cubes_out_dict: Mapping of colour name to that cube detector's input queue.
        e_stop: Shared stop event.
    """
    print("Preprocess thread started.")
    while not e_stop.is_set():
        try:
            color_image, depth_image = q_in.get(timeout=300)
            # Reversed last axis => BGR; copy() turns the strided view into its own array.
            bgr_image = copy(color_image[:, :, ::-1])
            xyz = create_xyz(depth_image)
            rgb, pts = reduce_pts(color_image, xyz)
            q_bgr_img_out.put(bgr_image, timeout=10)
            q_grasper_in.put([rgb, pts], timeout=10)
            for q in q_cubes_out_dict.values():
                q.put([rgb, pts], timeout=10)
        except Exception:
            traceback.print_exc()
            e_stop.set()


def grasper_detection_thread_target(q_in, q_out, e_stop):
    """Locate the grasper in each preprocessed frame and publish it on ``q_out``.

    For every ``[rgb, xyz]`` item from ``q_in``, candidate point sets matching
    ``grasper_color`` / ``grasper_dev`` are found with ``find_pts_grasper``.
    The largest set (by ``len``) is then fitted with ``find_grasper``, which
    receives the previous fit's ``'points'`` as ``prev_points``. The result,
    tagged with ``'color_name'`` and ``'color_lab'``, is published as a
    one-element list. When no candidates are found, an empty list is published
    and the previous fit is forgotten.

    Exactly one item is put on ``q_out`` per input frame. Any exception is
    printed and sets ``e_stop``, which ends this loop.

    Args:
        q_in: Queue of ``[rgb, xyz]`` lists from the preprocess stage.
        q_out: Queue receiving ``[obj]`` or ``[]`` per frame.
        e_stop: Shared stop event.
    """
    print("Grasper detection thread started")
    obj = {'points': None}
    while not e_stop.is_set():
        try:
            rgb, xyz = q_in.get(timeout=300)
            pts_list = find_pts_grasper(rgb, xyz, grasper_color, grasper_dev, debug=False)
            if len(pts_list) == 0:
                # Nothing matched: drop the previous fit and report no grasper.
                obj = {'points': None}
                q_out.put([], timeout=10)
                continue
            pts = max(pts_list, key=len)
            obj = find_grasper(pts, prev_points=obj['points'], debug=False)
            obj['color_name'] = 'grasper'
            obj['color_lab'] = grasper_color
            q_out.put([obj], timeout=10)
        except Exception:
            traceback.print_exc()
            e_stop.set()


def object_detection_thread_target(name, q_in, q_out, e_stop):
    """Detect cubes of colour ``name`` in each preprocessed frame.

    For every ``[rgb, xyz]`` item from ``q_in``:
      1. Candidate point sets are segmented with ``find_pts``, using
         ``colors[name]`` / ``color_devs[name]``.
      2. Each set is fitted with ``find_object_ransac``, which receives the
         previous frame's detections of this colour as ``prev_objects``.
      3. The list of successful fits is put on ``q_out``. Each fit is tagged
         with ``'color_name'`` and ``'color_lab'``; candidates for which the
         fit returns ``None`` are skipped.

    Exactly one list is put on ``q_out`` per input frame. Any exception is
    printed and sets ``e_stop``, which ends this loop.

    Args:
        name: Colour key into ``colors`` / ``color_devs``.
        q_in: Queue of ``[rgb, xyz]`` lists from the preprocess stage.
        q_out: Queue receiving one (possibly empty) list of detections per frame.
        e_stop: Shared stop event.
    """
    print("Detection thread for {} cube started.".format(name))
    last_object_list = []
    while not e_stop.is_set():
        try:
            rgb, xyz = q_in.get(timeout=300)
            pts_list = find_pts(rgb, xyz, colors[name], color_devs[name], debug=False)
            relevant_prev_objects = [obj for obj in last_object_list if obj['color_name'] == name]
            object_list = []
            for pts in pts_list:
                obj = find_object_ransac(pts, name, prev_objects=relevant_prev_objects, debug=False)
                if obj is None:
                    continue
                obj['color_name'] = name
                obj['color_lab'] = colors[name]
                object_list.append(obj)
            q_out.put(object_list, timeout=10)
            last_object_list = object_list
        except Exception:
            traceback.print_exc()
            e_stop.set()


def visiualization_thread_target(q_in_bgr_img, q_grasper_out, q_cubes_out_dict, e_stop):
    """Draw every detection onto its frame, display it, and report throughput.

    For each BGR image from ``q_in_bgr_img``, the matching detections are
    collected: the list from ``q_grasper_out``, followed by one list from each
    queue in ``q_cubes_out_dict``. They are drawn with ``draw_object`` and shown
    in the ``'Objects found'`` window, and the average frame rate since the
    first frame is printed. The upstream stages each emit exactly one item per
    frame, so the n-th item of every queue belongs to the n-th image.

    Any exception is printed and sets ``e_stop``, which ends this loop.
    (The misspelt function name is kept for compatibility with callers.)

    Args:
        q_in_bgr_img: Queue of BGR display images.
        q_grasper_out: Queue of grasper detection lists.
        q_cubes_out_dict: Mapping of colour name to that cube detector's output queue.
        e_stop: Shared stop event.
    """
    print("Visualization thread started.")
    n = 0
    t_start = None
    while not e_stop.is_set():
        try:
            c_img = q_in_bgr_img.get(timeout=300)
            if t_start is None:
                # Use a high-resolution monotonic clock. time.time() can jump
                # with clock changes and is coarse on some platforms, which
                # could make the first elapsed time 0 and crash the FPS division.
                t_start = time.perf_counter()
            n += 1
            object_list = list(q_grasper_out.get(timeout=300))
            for q_cubes_out in q_cubes_out_dict.values():
                object_list.extend(q_cubes_out.get(timeout=300))
            for obj in object_list:
                c_img = draw_object(obj, c_img)
            cv2.imshow('Objects found', c_img)
            cv2.waitKey(1)  # HighGUI only repaints imshow windows while processing events
            elapsed = time.perf_counter() - t_start
            if elapsed > 0:
                print("Vision FPS: {}".format(n / elapsed))
        except Exception:
            traceback.print_exc()
            e_stop.set()


if __name__ == "__main__":
    # --- Inter-stage plumbing ---------------------------------------------------
    # Capture runs in its own process, so the queue it writes to and the shared
    # stop event come from `multiprocessing`. All other stages are threads
    # (multiprocessing.dummy) and use thread-backed queues.
    q_out_capture = ProcessQueue(maxsize=1)
    e_stop = ProcessEvent()

    q_grasper_in = Queue(maxsize=1)
    q_grasper_out = Queue(maxsize=1)
    q_cubes_in_dict = dict((color_name, Queue(maxsize=1)) for color_name in colors)
    q_cubes_out_dict = dict((color_name, Queue(maxsize=10)) for color_name in colors)
    q_in_visualize_bgr_img = Queue(maxsize=1)

    # --- Frame source -----------------------------------------------------------
    # Alternative resolution: width, height = 1280, 720.
    width, height = 640, 480
    # Other capture targets that can be swapped in:
    #   capture_thread_target,           args=(q_out_capture, e_stop, width, height)  # live camera
    #   capture_from_file_thread_target, args=(q_out_capture, e_stop)                 # collect/*.npy
    capture_process = Process(
        target=capture_thread_from_bag_target,
        args=(q_out_capture, e_stop, 'test.bag', width, height),
    )
    capture_process.start()

    # --- Processing stages, listed (and started) in pipeline order --------------
    stage_threads = [
        Thread(target=preprocess_thread_target,
               args=(q_out_capture, q_in_visualize_bgr_img, q_grasper_in, q_cubes_in_dict, e_stop)),
        Thread(target=grasper_detection_thread_target,
               args=(q_grasper_in, q_grasper_out, e_stop)),
    ]
    stage_threads.extend(
        Thread(target=object_detection_thread_target,
               args=(color_name, q_cubes_in_dict[color_name], q_cubes_out_dict[color_name], e_stop))
        for color_name in colors
    )
    stage_threads.append(
        Thread(target=visiualization_thread_target,
               args=(q_in_visualize_bgr_img, q_grasper_out, q_cubes_out_dict, e_stop))
    )
    for stage in stage_threads:
        stage.start()

    # Wait for the capture process first, then for every stage in start order.
    capture_process.join()
    for stage in stage_threads:
        stage.join()